package com.flow.framework.cipher.algorithm;

import com.flow.framework.cipher.algorithm.version.Sm4Version;
import com.flow.framework.common.error.SystemErrorCode;
import com.flow.framework.common.exception.CheckedException;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.codec.binary.Base64;
import org.bouncycastle.jce.provider.BouncyCastleProvider;

import javax.crypto.Cipher;
import javax.crypto.KeyGenerator;
import javax.crypto.spec.SecretKeySpec;
import java.nio.charset.StandardCharsets;
import java.security.Key;
import java.security.SecureRandom;
import java.security.Security;

/**
 * 国密4封装
 *
 * @author luoguopiao
 * @version 0.0.1
 * @date 2022/1/16
 */
@Slf4j
public class Sm4 {

    /**
     * 加密名称SM4
     */
    private static final String SM4 = "SM4";

    static {
        Security.addProvider(new BouncyCastleProvider());
    }

    /**
     * 获取 ECB cipher
     *
     * @param version version
     * @param mode    mode
     * @param key     key
     * @return
     */
    private static Cipher generateEcbCipher(Sm4Version version, int mode, byte[] key) {
        try {
            Cipher cipher = Cipher.getInstance(version.getCipherAlgorithm(), BouncyCastleProvider.PROVIDER_NAME);
            Key sm4Key = new SecretKeySpec(key, SM4);
            cipher.init(mode, sm4Key);
            return cipher;
        } catch (Exception e) {
            log.error("generate cipher error.", e);
            throw new CheckedException(SystemErrorCode.CIPHER_ERROR, "generate cipher error.", e);
        }
    }

    /**
     * 获取秘钥
     *
     * @param version version
     * @return
     */
    static String generateKey(Sm4Version version) {
        try {
            KeyGenerator kg = KeyGenerator.getInstance(SM4, BouncyCastleProvider.PROVIDER_NAME);
            kg.init(version.getKeySize(), new SecureRandom());
            byte[] keyBytes = kg.generateKey().getEncoded();
            return Base64.encodeBase64String(keyBytes);
        } catch (Exception e) {
            log.error("generate key error.", e);
            throw new CheckedException(SystemErrorCode.CIPHER_ERROR, "generate key error.", e);
        }
    }

    /**
     * 加密字符串
     *
     * @param version          version
     * @param content          content
     * @param encryptBase64Key encryptBase64Key
     * @return
     */
    static String encrypt(Sm4Version version, String content, String encryptBase64Key) {
        if (null == encryptBase64Key || null == content) {
            log.error("decrypt error. params can't be empty.");
            throw new CheckedException(SystemErrorCode.PARAMS_ERROR);
        }

        try {
            byte[] contentBytes = content.getBytes(StandardCharsets.UTF_8);
            byte[] resultBytes = encrypt(version, contentBytes, encryptBase64Key);
            return Base64.encodeBase64String(resultBytes);
        } catch (Exception e) {
            log.error("encrypt error.", e);
            throw new CheckedException(SystemErrorCode.CIPHER_ERROR, "encrypt error.", e);
        }
    }

    /**
     * 加密字符串
     *
     * @param version          version
     * @param contentBytes     contentBytes
     * @param encryptBase64Key encryptBase64Key
     * @return
     */
    static byte[] encrypt(Sm4Version version, byte[] contentBytes, String encryptBase64Key) {
        if (null == encryptBase64Key || null == contentBytes) {
            log.error("decrypt error. params can't be empty.");
            throw new CheckedException(SystemErrorCode.PARAMS_ERROR);
        }

        try {
            byte[] keyBytes = Base64.decodeBase64(encryptBase64Key);
            Cipher cipher = generateEcbCipher(version, Cipher.ENCRYPT_MODE, keyBytes);
            return AlgorithmUtil.cipherDoFinal(contentBytes, cipher, version.getEncryptBlockSize());
        } catch (Exception e) {
            log.error("encrypt error.", e);
            throw new CheckedException(SystemErrorCode.CIPHER_ERROR, "encrypt error.", e);
        }
    }

    /**
     * 解密字符串
     *
     * @param version          version
     * @param content          content
     * @param decryptBase64Key decryptBase64Key
     * @return
     */
    static String decrypt(Sm4Version version, String content, String decryptBase64Key) {
        if (null == decryptBase64Key || null == content) {
            log.error("decrypt error. params can't be empty.");
            throw new CheckedException(SystemErrorCode.PARAMS_ERROR);
        }
        try {
            byte[] contentBytes = Base64.decodeBase64(content);
            byte[] resultBytes = decrypt(version, contentBytes, decryptBase64Key);
            return new String(resultBytes, StandardCharsets.UTF_8);
        } catch (Exception e) {
            log.error("decrypt error.", e);
            throw new CheckedException(SystemErrorCode.CIPHER_ERROR, "decrypt error.", e);
        }
    }

    /**
     * 解密字符串
     *
     * @param version       version
     * @param contentBytes  contentBytes
     * @param decryptHexKey decryptHexKey
     * @return
     */
    static byte[] decrypt(Sm4Version version, byte[] contentBytes, String decryptHexKey) {
        if (null == decryptHexKey || null == contentBytes) {
            log.error("decrypt error. params can't be empty.");
            throw new CheckedException(SystemErrorCode.PARAMS_ERROR);
        }
        try {
            byte[] keyBytes = Base64.decodeBase64(decryptHexKey);
            Cipher cipher = generateEcbCipher(version, Cipher.DECRYPT_MODE, keyBytes);
            return AlgorithmUtil.cipherDoFinal(contentBytes, cipher, version.getDecryptBlockSize());
        } catch (Exception e) {
            log.error("decrypt error.", e);
            throw new CheckedException(SystemErrorCode.CIPHER_ERROR, "decrypt error.", e);
        }
    }
}
